{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PE8tTD1OMJCM"
      },
      "outputs": [],
      "source": [
        "# %pip (unlike !pip) installs into the environment of the running kernel.\n",
        "%pip install -q transformers captum"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from transformers import DistilBertForSequenceClassification, DistilBertTokenizer\n",
        "from captum.attr import LayerIntegratedGradients, visualization as viz\n",
        "import torch\n",
        "\n",
        "def visualize_sentiment(text: str):\n",
        "    \"\"\"\n",
        "    Visualizes token-level sentiment attributions for the given text.\n",
        "\n",
        "    Loads the SST-2 fine-tuned DistilBERT classifier, predicts the class of\n",
        "    `text`, runs Layer Integrated Gradients over the embedding layer for the\n",
        "    predicted class and renders the per-token attributions with Captum's\n",
        "    text visualizer. Nothing is returned; the result is displayed inline.\n",
        "\n",
        "    Args:\n",
        "        text (str): The text to visualize.\n",
        "    \"\"\"\n",
        "\n",
        "    # Pre-trained model and tokenizer\n",
        "    model_path = 'distilbert-base-uncased-finetuned-sst-2-english'\n",
        "    model = DistilBertForSequenceClassification.from_pretrained(model_path)\n",
        "    tokenizer = DistilBertTokenizer.from_pretrained(model_path)\n",
        "    model.eval()\n",
        "\n",
        "    # Function to create input tensors and baseline for the given text\n",
        "    def construct_input_and_baseline(input_text: str):\n",
        "        \"\"\"Constructs input and baseline tensors for the given text.\n",
        "\n",
        "        The baseline keeps [CLS] and [SEP] in place and replaces every text\n",
        "        token with the padding token.\n",
        "\n",
        "        Returns:\n",
        "            tuple: (input_ids, baseline_input_ids, token_list), where both\n",
        "            tensors hold a single sequence and token_list holds the tokens\n",
        "            of input_ids.\n",
        "        \"\"\"\n",
        "        # [CLS] and [SEP] are added by hand below, so two positions of the\n",
        "        # model's position-embedding budget must stay free for them.\n",
        "        max_length = model.config.max_position_embeddings - 2\n",
        "        baseline_token_id = tokenizer.pad_token_id\n",
        "        sep_token_id = tokenizer.sep_token_id\n",
        "        cls_token_id = tokenizer.cls_token_id\n",
        "\n",
        "        text_ids = tokenizer.encode(input_text, max_length=max_length, truncation=True, add_special_tokens=False)\n",
        "        input_ids = [cls_token_id] + text_ids + [sep_token_id]\n",
        "        baseline_input_ids = [cls_token_id] + [baseline_token_id] * len(text_ids) + [sep_token_id]\n",
        "        token_list = tokenizer.convert_ids_to_tokens(input_ids)\n",
        "\n",
        "        device = model.device\n",
        "        return torch.tensor([input_ids], device=device), torch.tensor([baseline_input_ids], device=device), token_list\n",
        "\n",
        "    # Constructing input and baseline\n",
        "    input_ids, baseline_input_ids, all_tokens = construct_input_and_baseline(text)\n",
        "\n",
        "    # Defining a function for the model's output\n",
        "    def model_output(inputs):\n",
        "        \"\"\"Returns the first element of the model output (the class logits).\"\"\"\n",
        "        return model(inputs)[0]\n",
        "\n",
        "    # Predicting the class: one forward pass, no gradients needed here.\n",
        "    with torch.no_grad():\n",
        "        logits = model_output(input_ids)\n",
        "    # Softmax turns the raw logits into probabilities for display.\n",
        "    probabilities = torch.softmax(logits, dim=-1).squeeze(0)\n",
        "    pred_class = int(torch.argmax(probabilities))\n",
        "    pred_prob = probabilities[pred_class].item()\n",
        "\n",
        "    # Layer Integrated Gradients\n",
        "    lig = LayerIntegratedGradients(model_output, model.distilbert.embeddings)\n",
        "\n",
        "    # Only the predicted class is visualized, so only that class is attributed.\n",
        "    attributions, delta = lig.attribute(\n",
        "        inputs=input_ids,\n",
        "        baselines=baseline_input_ids,\n",
        "        target=pred_class,\n",
        "        return_convergence_delta=True,\n",
        "        internal_batch_size=1)\n",
        "\n",
        "    # Summarizing attributions: collapse the embedding dimension to one score\n",
        "    # per token, then scale by the norm of that per-token vector so every\n",
        "    # score lies in [-1, 1], the range the visualizer's colour scale covers.\n",
        "    summarized_attr = attributions.sum(dim=-1).squeeze(0)\n",
        "    attr_norm = torch.norm(summarized_attr)\n",
        "    if attr_norm > 0:  # all-zero attributions (e.g. empty text) stay zero\n",
        "        summarized_attr = summarized_attr / attr_norm\n",
        "\n",
        "    # Visualization data\n",
        "    score_vis = viz.VisualizationDataRecord(\n",
        "        word_attributions=summarized_attr,\n",
        "        pred_prob=pred_prob,\n",
        "        pred_class=pred_class,\n",
        "        true_class=None,\n",
        "        # Label of the class the attributions were computed for.\n",
        "        attr_class=model.config.id2label.get(pred_class, pred_class),\n",
        "        attr_score=summarized_attr.sum().item(),\n",
        "        raw_input_ids=all_tokens,\n",
        "        convergence_score=delta)\n",
        "\n",
        "    # Visualizing the result\n",
        "    viz.visualize_text([score_vis])\n"
      ],
      "metadata": {
        "id": "vThfBIUdT5mq"
      },
      "execution_count": 58,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Sample review pairing a negated negative (\"not bad\") with plainly positive wording.\n",
        "text = \"The movie was not bad as mentioned by critics. It was in fact awesome; I enjoyed the whole time\"\n",
        "visualize_sentiment(text=text)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 141
        },
        "id": "VlYR7OK9WbxC",
        "outputId": "f65c9b19-eb87-4d89-966f-3d2fa62b4e38"
      },
      "execution_count": 60,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ],
            "text/html": [
              "<table width: 100%><div style=\"border-top: 1px solid; margin-top: 5px;             padding-top: 5px; display: inline-block\"><b>Legend: </b><span style=\"display: inline-block; width: 10px; height: 10px;                 border: 1px solid; background-color:                 hsl(0, 75%, 60%)\"></span> Negative  <span style=\"display: inline-block; width: 10px; height: 10px;                 border: 1px solid; background-color:                 hsl(0, 75%, 100%)\"></span> Neutral  <span style=\"display: inline-block; width: 10px; height: 10px;                 border: 1px solid; background-color:                 hsl(120, 75%, 50%)\"></span> Positive  </div><tr><th>True Label</th><th>Predicted Label</th><th>Attribution Label</th><th>Attribution Score</th><th>Word Importance</th><tr><td><text style=\"padding-right:2em\"><b>None</b></text></td><td><text style=\"padding-right:2em\"><b>1 (4.65)</b></text></td><td><text style=\"padding-right:2em\"><b>The movie was not bad as mentioned by critics. 
It was in fact awesome; I enjoyed the whole time</b></text></td><td><text style=\"padding-right:2em\"><b>12.92</b></text></td><td><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> [CLS]                    </font></mark><mark style=\"background-color: hsl(120, 75%, 86%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> the                    </font></mark><mark style=\"background-color: hsl(120, 75%, 81%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> movie                    </font></mark><mark style=\"background-color: hsl(120, 75%, 64%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> was                    </font></mark><mark style=\"background-color: hsl(120, 75%, 50%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> not                    </font></mark><mark style=\"background-color: hsl(120, 75%, 50%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> bad                    </font></mark><mark style=\"background-color: hsl(120, 75%, 50%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> as                    </font></mark><mark style=\"background-color: hsl(120, 75%, 98%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> mentioned                    </font></mark><mark style=\"background-color: hsl(120, 75%, 80%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> by                    </font></mark><mark style=\"background-color: hsl(120, 75%, 80%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> critics                    </font></mark><mark style=\"background-color: hsl(120, 75%, 74%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> .                    
</font></mark><mark style=\"background-color: hsl(120, 75%, 80%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> it                    </font></mark><mark style=\"background-color: hsl(0, 75%, 93%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> was                    </font></mark><mark style=\"background-color: hsl(0, 75%, 98%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> in                    </font></mark><mark style=\"background-color: hsl(120, 75%, 69%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> fact                    </font></mark><mark style=\"background-color: hsl(120, 75%, 50%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> awesome                    </font></mark><mark style=\"background-color: hsl(0, 75%, 71%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> ;                    </font></mark><mark style=\"background-color: hsl(120, 75%, 81%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> i                    </font></mark><mark style=\"background-color: hsl(120, 75%, 50%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> enjoyed                    </font></mark><mark style=\"background-color: hsl(0, 75%, 92%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> the                    </font></mark><mark style=\"background-color: hsl(0, 75%, 71%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> whole                    </font></mark><mark style=\"background-color: hsl(0, 75%, 96%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> time                    </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> [SEP]                    </font></mark></td><tr></table>"
            ]
          },
          "metadata": {}
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "49m8hDLlWd8e"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}